# Learning-to-Focus

> Learning to Focus: Causal Attention Distillation via Gradient-Guided Token Pruning

---

## Table of Contents

- [Overview](#overview)  
- [Features](#features)  
- [Dependencies](#dependencies)  
- [Installation](#installation)  
- [Repository Structure](#repository-structure)  
- [Usage](#usage)  
  - [1. Gradient-Based Comparison](#1-gradient-based-comparison)  
  - [2. Remove Misleading Patterns](#2-remove-misleading-patterns)  
  - [3. Causal Attention Distillation](#3-causal-attention-distillation)  
- [License](#license)

---

## Overview

This repository implements the methods described in **Learning to Focus: Causal Attention Distillation via Gradient-Guided Token Pruning**. We present a two-stage framework:

1. **Misleading Token Detection**  
   Identify spurious “confounder” tokens in the training corpus via gradient comparisons between a high-capacity teacher model and a student model.  
2. **Causal Attention Distillation (CAD)**  
   Prune detected confounders and align the student model’s attention distributions to the teacher’s, capturing true causal dependencies and improving reasoning robustness.
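
As a rough illustration of the detection stage, the sketch below assumes per-token gradient magnitudes have already been computed for both models. It flags tokens whose normalized sensitivity is much higher for the student than for the teacher. The function names, the min-max normalization, and the `threshold` value are hypothetical choices for illustration and may differ from the criterion used in `src/Misleading_pattern_detection`.

```python
def normalize(scores):
    """Min-max normalize a list of scores to the range [0, 1]."""
    if not scores:
        return []
    lo, hi = min(scores), max(scores)
    if hi == lo:
        # All scores equal: no token stands out.
        return [0.0 for _ in scores]
    return [(s - lo) / (hi - lo) for s in scores]


def flag_misleading_tokens(teacher_grads, student_grads, threshold=0.5):
    """Return indices where the student is far more sensitive than the teacher.

    ASSUMPTION: the gap between normalized gradient magnitudes is used as the
    detection signal; the threshold is an illustrative value.
    """
    if len(teacher_grads) != len(student_grads):
        raise ValueError("gradient lists must have the same length")
    teacher = normalize(teacher_grads)
    student = normalize(student_grads)
    return [
        i
        for i, (t, s) in enumerate(zip(teacher, student))
        if s - t > threshold
    ]


# Toy example with made-up gradient magnitudes (not real model output).
tokens = ["Tom", "has", "3", "red", "apples"]
teacher_grads = [0.1, 0.1, 0.9, 0.1, 0.5]
student_grads = [0.1, 0.1, 0.5, 0.9, 0.3]
flagged = flag_misleading_tokens(teacher_grads, student_grads)
print([tokens[i] for i in flagged])
```

In this toy input the student attends strongly to "red" while the teacher does not, so only that token is flagged.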

---

## Features

- 🔍 **Misleading Token Detection**  
  - Instruction-level and response-level gradient comparisons between teacher and student models.  
- ✂️ **Remove Misleading Patterns**  
  - Prune identified confounders in code and math corpora.  
- 🎯 **Causal Attention Distillation**  
  - Minimize KL divergence between teacher and student attention distributions at both instruction and response levels.
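
To make the distillation objective concrete, here is a minimal pure-Python sketch of the KL divergence between two attention distributions over the same tokens. The direction KL(teacher || student) and the `eps` clamp are assumptions made for this example; the training scripts operate on model tensors and may differ in both respects.

```python
import math


def kl_divergence(teacher_attn, student_attn, eps=1e-12):
    """Compute KL(teacher || student) for two discrete distributions.

    ASSUMPTION: both inputs are attention weights over the same tokens and
    each sums to 1. `eps` guards against division by zero in the student.
    """
    if len(teacher_attn) != len(student_attn):
        raise ValueError("distributions must have the same length")
    total = 0.0
    for p, q in zip(teacher_attn, student_attn):
        if p > 0.0:  # terms with p == 0 contribute nothing
            total += p * math.log(p / max(q, eps))
    return total


# Toy attention rows (made-up values).
teacher_row = [0.5, 0.5]
student_row = [0.9, 0.1]
loss = kl_divergence(teacher_row, student_row)
```

The loss is zero when the two rows match and grows as the student's attention drifts away from the teacher's.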

---

## Dependencies

- Python 3.10+  
- PyTorch  
- Hugging Face Transformers  

See [requirements.txt](./requirements.txt) for full details.

---

## Installation

```bash
pip install -r requirements.txt
```
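
After installing, a quick sanity check can confirm that the interpreter and core packages listed under Dependencies are present. The helper below is a hypothetical convenience, not part of the repository; it only checks that `torch` and `transformers` are importable, not that their versions match `requirements.txt`.

```python
import importlib.util
import sys


def check_environment(min_version=(3, 10), packages=("torch", "transformers")):
    """Return a list of problems found; an empty list means the check passed."""
    problems = []
    if sys.version_info < min_version:
        required = ".".join(str(part) for part in min_version)
        problems.append(f"Python {required}+ required")
    for name in packages:
        # find_spec returns None when a top-level package cannot be located.
        if importlib.util.find_spec(name) is None:
            problems.append(f"missing package: {name}")
    return problems


for problem in check_environment():
    print(problem)
```

No output means the basic requirements are satisfied.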

## Repository Structure

```text
Code_for_Learning_to_Focus/
├── data
│   ├── llama_1b_instruct_level
│   │   ├── Distill_NuminaMATH_llama_1b_misleading_0.10.json
│   │   └── NuminaMATH_eval_1035_samples.json
│   ├── llama_1b_response_level
│   │   ├── Distill_NuminaMath_llama_1b_misleading_step_0.075.json
│   │   └── Numina_eval_1035_with_step_response.json
│   └── llama_1b_gradient_data
│       └── Numina_train_data_llama_1b_1.2w.json
├── scripts
│   ├── Misleading_pattern_detection
│   │   ├── gradient_comparison
│   │   │   ├── run_instruct_level_llama.sh
│   │   │   └── run_instruct_level_qwen.sh
│   │   └── remove_misleading_patterns
│   │       ├── run_code.sh
│   │       └── run_math.sh
│   └── Causal_attention_distillation
│       ├── instruct_level_distillation
│       │   ├── run_llama3.2_1b_instruct.sh
│       │   ├── run_llama3.2_3b_instruct.sh
│       │   └── run_qwen2.5_1.5b_math.sh
│       └── response_level_distillation
│           ├── run_llama3.2_1b_response.sh
│           ├── run_llama3.2_3b_response.sh
│           └── run_qwen2.5_1.5b_response.sh
├── src
│   ├── Causal_attention_distillation
│   │   ├── CAD_instruct_level
│   │   │   └── distill_instruct_level.py
│   │   └── CAD_response_level
│   │       └── distill_response_level.py
│   └── Misleading_pattern_detection
│       ├── run_gradients_llama_instruct_level.py
│       ├── run_gradients_llama_response_level.py
│       ├── run_gradients_qwen_instruct_level.py
│       ├── run_gradients_qwen_response_level.py
│       ├── remove_misleading_patterns_code.py
│       └── remove_misleading_patterns_math.py
├── requirements.txt
└── README.md

```

## Usage

### 1. Gradient-Based Comparison

Detect misleading tokens by comparing teacher and student gradients.

- **Instruction-Level**  
    ```bash
    bash scripts/Misleading_pattern_detection/gradient_comparison/run_instruct_level_llama.sh
    bash scripts/Misleading_pattern_detection/gradient_comparison/run_instruct_level_qwen.sh
    ```

- **Response-Level**  
    ```bash
    bash scripts/Misleading_pattern_detection/gradient_comparison/run_response_level_llama.sh
    bash scripts/Misleading_pattern_detection/gradient_comparison/run_response_level_qwen.sh
    ```

### 2. Remove Misleading Patterns

Prune identified confounders in the training corpus.

- **Code Corpus**  
    ```bash
    bash scripts/Misleading_pattern_detection/remove_misleading_patterns/run_code.sh
    ```

- **Math Corpus**  
    ```bash
    bash scripts/Misleading_pattern_detection/remove_misleading_patterns/run_math.sh
    ```
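
Conceptually, pruning drops the flagged tokens from each training example before distillation. The sketch below shows that step on a plain token list; `prune_tokens` and the whitespace-level tokens are hypothetical simplifications, and the actual scripts may work on model tokenizer output or a different data layout.

```python
def prune_tokens(tokens, misleading_indices):
    """Return a copy of `tokens` without the positions flagged as misleading."""
    drop = set(misleading_indices)
    return [tok for i, tok in enumerate(tokens) if i not in drop]


# Toy example: index 3 ("red") was flagged as a confounder.
tokens = ["Tom", "has", "3", "red", "apples"]
pruned = prune_tokens(tokens, [3])
print(" ".join(pruned))
```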

### 3. Causal Attention Distillation

Align student attention to teacher attention.

- **Instruction-Level Distillation**  
    ```bash
    bash scripts/Causal_attention_distillation/instruct_level_distillation/run_llama3.2_1b_instruct.sh
    bash scripts/Causal_attention_distillation/instruct_level_distillation/run_llama3.2_3b_instruct.sh
    bash scripts/Causal_attention_distillation/instruct_level_distillation/run_qwen2.5_1.5b_math.sh
    ```

- **Instruction + Response-Level Distillation**  
    ```bash
    bash scripts/Causal_attention_distillation/response_level_distillation/run_llama3.2_1b_response.sh
    bash scripts/Causal_attention_distillation/response_level_distillation/run_llama3.2_3b_response.sh
    bash scripts/Causal_attention_distillation/response_level_distillation/run_qwen2.5_1.5b_response.sh
    ```

